import pandas as pd
import numpy as np
from numba import njit
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LogisticRegression as Mnl
from sklearn.linear_model import LinearRegression as Ols
import itertools


def setup_data(state, mnl, contracts, polynomial_order, contracts_have_state=True):
    """Expand the state variables into polynomial terms and attach them by date.

    Args:
        state: DataFrame holding a 'date' column and the state columns
            'g', 'n_l', 'n_m' and 'n_h'.
        mnl: DataFrame used for the multinomial logit; it is left-merged
            with the polynomial state on 'date'.
        contracts: DataFrame used for the linear regressions; it is
            left-merged with the polynomial state on 'date'.
        polynomial_order: degree passed to PolynomialFeatures.
        contracts_have_state: when True, the raw state columns are dropped
            from `contracts` before merging, because the polynomial
            expansion already contains columns with those same names.

    Returns:
        contracts_poly: `contracts` with the polynomial state columns.
        mnl_poly: `mnl` with the polynomial state columns.
        poly_names: list of the polynomial column names.
        poly: the fitted PolynomialFeatures instance.

    """
    # Setup state names
    state_names = ['g', 'n_l', 'n_m', 'n_h']

    # No bias column: the downstream estimators fit their own intercept.
    poly = PolynomialFeatures(polynomial_order,
                              include_bias=False)

    # Get data in the polynomial form
    poly_fit = poly.fit_transform(state[state_names])

    # `get_feature_names` was removed in scikit-learn 1.2; prefer the
    # replacement when it exists and fall back for older releases. The
    # result is coerced to a list because callers concatenate it with
    # other column lists.
    if hasattr(poly, 'get_feature_names_out'):
        poly_names = list(poly.get_feature_names_out(state_names))
    else:
        poly_names = list(poly.get_feature_names(state_names))

    # Indexing by the 'date' Series names the index 'date', which is what
    # the merges below join on.
    state_poly = pd.DataFrame(poly_fit,
                              columns=poly_names,
                              index=state['date'])

    mnl_poly = mnl.merge(state_poly,
                         on='date',
                         how='left')

    if contracts_have_state:
        contracts_base = contracts.drop(columns=state_names)
    else:
        contracts_base = contracts

    contracts_poly = contracts_base.merge(state_poly,
                                          on='date',
                                          how='left')

    return contracts_poly, mnl_poly, poly_names, poly

def estimate_mnl(mnl_poly, poly_names, poly):
    """ Estimates a multinomial logit model for each rig spec.

    Args:
        mnl_poly: data in the form for multinomial logit regression
            that includes the polynomial state columns, a 'rig_spec'
            column and a 'tau' column.
        poly_names: the names of the polynomial state columns in
            the dataframe.
        poly: the polynomial class

    Returns:
        output: a dictionary of dictionaries of functions that map a
            state to predicted probabilities. This is in the form of:
                output[rig spec]['tau']
        reg: a dictionary of the fitted estimators, keyed by
            'tau <rig spec>'.

    """
    # Accept any sequence of names (e.g. the array returned by
    # `get_feature_names_out`); list concatenation below needs a list.
    poly_names = list(poly_names)

    reg = dict()
    output = {'low': dict(),
              'mid': dict(),
              'high': dict()}

    def reg_fn_1(i, j):
        """ This is a closure wrapper to store the values i and
        j.

        Args:
            i: ['tau']
            j: rig type in ['low','mid','high']

        Returns:
            out: a function that applies the model fitted for
                (i, j) to state data.

        """
        key = i + ' ' + j

        # Rows of this rig spec with complete target and regressors.
        mnl_spec = mnl_poly[(mnl_poly['rig_spec'] == j)]
        sample = mnl_spec.dropna(subset=[i] + poly_names)

        mnl = Mnl(random_state=0,
                  multi_class='multinomial',
                  solver='newton-cg',
                  max_iter=10000)

        # fn_1 looks the fitted model up in `reg` by key.
        reg[key] = mnl.fit(sample[poly_names], sample[i])

        def fn_1(s, vector=False):
            """ Transform a state to probabilities of matching
            each type of contract length (note that the prob
            of matching no contract can be computed as
            1 - sum(probs of matching)

            Args:
                s: a single state [g,n_l,n_m,n_h], or, when
                    `vector` is truthy, a 2-D array of states with
                    those columns.
                vector: whether `s` is already 2-D.

            Returns:
                out: columns 1 to 3 of `predict_proba` for the
                    first row of `s` (described by the original
                    author as contract lengths [2,3,4]).

            """
            # Branch on truthiness: the previous identity checks
            # against True/False left `fit` unbound for flags such
            # as numpy booleans or integers.
            if vector:
                fit = poly.fit_transform(s)
            else:
                fit = poly.fit_transform(np.array(s).reshape(1, -1))

            # NOTE(review): only the first row is returned even when
            # `vector` is truthy; kept as-is for callers that rely on
            # it -- confirm whether all rows were intended.
            out = reg[key].predict_proba(fit)[:, 1:4][0]
            return out

        return fn_1

    for i, j in itertools.product(['tau'],
                                  ['low',
                                   'mid',
                                   'high']):
        output[j][i] = reg_fn_1(i, j)

    return output, reg


def estimate_ols(contracts_poly, poly_names, poly,
                 column_names=('mri', 'value', 'day_rate')):
    """ Estimates a linear model for each outcome column and rig spec.

    Args:
        contracts_poly: data in the form for a linear regression
            that includes the polynomial state columns, a 'spec'
            column, a 'tau' column and the outcome columns.
        poly_names: the names of the polynomial state columns in
            the dataframe.
        poly: the polynomial class
        column_names: the outcome columns to regress on the polynomial
            state and 'tau'.

    Returns:
        output: a dictionary of dictionaries in the form of:

                output[rig spec]['mean_' + column]

            Each entry is the tuple (fn_1, reg), where fn_1 predicts
            the outcome from a state and tau.
            NOTE(review): the tuple (rather than the bare function) is
            kept for backward compatibility -- confirm it is intended.
        reg: a dictionary of the fitted estimators, keyed by
            '<column> <rig spec>'.

    """
    # Accept any sequence of names; list concatenation below needs a list.
    poly_names = list(poly_names)
    regressors = poly_names + ['tau']

    reg = dict()
    output = {'low': dict(),
              'mid': dict(),
              'high': dict()}

    def reg_fn_1(i, j):
        """ This is a closure wrapper to store the values i and
        j.

        Args:
            i: outcome column, one of `column_names`
            j: rig type in ['low','mid','high']

        Returns:
            out: the tuple (fn_1, reg), where fn_1 applies the
                linear model fitted for (i, j) to state data.

        """
        key = i + ' ' + j

        # Rows of this rig spec with complete outcome and regressors.
        mask = (contracts_poly['spec'] == j)
        sample = contracts_poly[mask].dropna(subset=[i] + regressors)

        # reg is the object that fn_1 reads the fitted model from
        reg[key] = Ols().fit(sample[regressors], sample[i])

        def fn_1(s, tau):
            """ Predict the outcome for one state or a batch of states.

            Args:
                s: a single state [g,n_l,n_m,n_h] or a 2-D array of
                    states with those columns.
                tau: scalar appended as the last regressor of
                    every row.

            Returns:
                out: array of predictions, clipped below at 0.

            """
            # Lists and other array-likes have no .shape/.reshape.
            s = np.asarray(s)

            if s.ndim == 1:
                s = s.reshape(1, -1)
                tau_column = np.array([[tau]])
            else:
                tau_column = np.repeat([tau], len(s)).reshape(-1, 1)

            x = np.append(poly.fit_transform(s), tau_column, axis=1)

            prediction = reg[key].predict(x)

            # Ensure prediction is >= 0
            return prediction.clip(min=0)

        return fn_1, reg

    for i, j in itertools.product(column_names,
                                  ['low',
                                   'mid',
                                   'high']):
        output[j]['mean_' + i] = reg_fn_1(i, j)

    return output, reg



def build_ols_predictor(coefs):
    """Build a numba-compiled linear predictor from a coefficient vector.

    Args:
        coefs: indexable of 16 coefficients: coefs[0] is the intercept,
            coefs[1:5] multiply g, n_l, n_m, n_h, coefs[5:15] multiply
            the second-order terms of those four variables, and
            coefs[15] multiplies tau.

    Returns:
        ols_predictor: function (state, tau) -> linear prediction, where
            state is indexed as [g, n_l, n_m, n_h].

    """

    @njit
    def ols_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # Intercept and first-order terms.
        first_order = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
        )

        # Squares and pairwise interactions, added in the same
        # left-to-right order as a single flat sum.
        second_order = (
            first_order
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
        )

        return second_order + coefs[15] * tau

    return ols_predictor


def build_mnl_predictor(coefs):
    """Build a numba-compiled multinomial-logit (softmax) predictor.

    Args:
        coefs: 2-D array indexed as coefs[k, m] with 4 rows (one per
            outcome) and 15 columns: column 0 is the intercept, columns
            1-4 multiply g, n_l, n_m, n_h and columns 5-14 multiply the
            second-order terms of those four variables.

    Returns:
        mnl_predictor: function state -> array of 4 probabilities that
            sum to one, where state is indexed as [g, n_l, n_m, n_h].

    """

    @njit
    def mnl_predictor(state):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        logits = np.zeros(4)
        for k in range(4):
            logits[k] = (
                coefs[k, 0]
                + coefs[k, 1] * g
                + coefs[k, 2] * n_l
                + coefs[k, 3] * n_m
                + coefs[k, 4] * n_h
                + coefs[k, 5] * g * g
                + coefs[k, 6] * g * n_l
                + coefs[k, 7] * g * n_m
                + coefs[k, 8] * g * n_h
                + coefs[k, 9] * n_l * n_l
                + coefs[k, 10] * n_l * n_m
                + coefs[k, 11] * n_l * n_h
                + coefs[k, 12] * n_m * n_m
                + coefs[k, 13] * n_m * n_h
                + coefs[k, 14] * n_h * n_h
            )

        # Softmax is invariant to a common shift of the logits. Subtracting
        # the largest one keeps every exponent <= 0, so exp cannot overflow
        # to inf and turn the ratio into nan.
        shifted = logits - np.max(logits)
        weights = np.exp(shifted)

        return weights / np.sum(weights)

    return mnl_predictor


def build_prob_extension_predictor(coefs):
    """Build a numba-compiled binary-logit probability predictor.

    Args:
        coefs: indexable of 16 coefficients: coefs[0] is the intercept,
            coefs[1:5] multiply g, n_l, n_m, n_h, coefs[5:15] multiply
            the second-order terms of those four variables, and
            coefs[15] multiplies tau.

    Returns:
        prob_extension_predictor: function (state, tau) -> logistic
            transform of the linear index, a value in [0, 1]. state is
            indexed as [g, n_l, n_m, n_h].

    """

    @njit
    def prob_extension_predictor(state, tau):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        x = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
            + coefs[15] * tau
        )

        # Same value as exp(x) / (1 + exp(x)), but this form does not
        # evaluate inf / inf (nan) when x is large: exp(-x) underflows to
        # 0 for large x and overflows to inf (giving 0) for very negative x.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor


def build_prob_extension_predictor_with_mri(coefs):
    """Build a numba-compiled binary-logit probability predictor with mri.

    Args:
        coefs: indexable of 17 coefficients: coefs[0] is the intercept,
            coefs[1:5] multiply g, n_l, n_m, n_h, coefs[5:15] multiply
            the second-order terms of those four variables, coefs[15]
            multiplies tau and coefs[16] multiplies mri.

    Returns:
        prob_extension_predictor_with_mri: function (state, tau, mri) ->
            logistic transform of the linear index, a value in [0, 1].
            state is indexed as [g, n_l, n_m, n_h].

    """

    @njit
    def prob_extension_predictor_with_mri(state, tau, mri):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        x = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
            + coefs[15] * tau
            + coefs[16] * mri
        )

        # Same value as exp(x) / (1 + exp(x)), but this form does not
        # evaluate inf / inf (nan) when x is large: exp(-x) underflows to
        # 0 for large x and overflows to inf (giving 0) for very negative x.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor_with_mri


def build_price_extension_predictor(coefs):
    """Build a numba-compiled linear predictor from a coefficient vector.

    Args:
        coefs: indexable of 16 coefficients: coefs[0] is the intercept,
            coefs[1:5] multiply g, n_l, n_m, n_h, coefs[5:15] multiply
            the second-order terms of those four variables, and
            coefs[15] multiplies tau.

    Returns:
        price_extension_predictor: function (state, tau) -> the linear
            index, returned without any transformation. state is
            indexed as [g, n_l, n_m, n_h].

    """

    @njit
    def price_extension_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # Intercept and level terms.
        level_part = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
        )

        # Terms involving g, then the remaining n_* products; the running
        # sum keeps the left-to-right order of one flat expression.
        with_g_terms = (
            level_part
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
        )
        with_n_terms = (
            with_g_terms
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
        )

        return with_n_terms + coefs[15] * tau

    return price_extension_predictor
